
[bugfix] avoid attention padding tokens computation in pcg #17706

Merged
ispobock merged 43 commits into sgl-project:main from Chen-0210:fix-qwen3-next on Apr 14, 2026

Conversation

@Chen-0210 (Contributor) commented Jan 25, 2026

Motivation

When PCG is enabled, the attention metadata is initialized with real_num_tokens, but the input tensor still contains padded tokens. Attention backends such as FlashInfer cannot handle this mismatch, which can lead to undefined behavior: NaN values, corrupted outputs (repeated `!!!!!`), and abnormally long output lengths.

To fix this, exclude the padded tokens from the attention computation, making PCG more robust.

Modifications

  1. Reverts PR fix: piecewise_cuda_graph get correct qo_indptr #21452.
  2. Slices tensors to real_num_tokens around the attention forward call: in unified_attention_with_output and unified_linear_attention_with_output, query/key/value and out_cache_loc are narrowed to exclude padding tokens before calling the backend, and out_cache_loc is restored afterward, so the attention backend only ever sees real tokens (see the sketch below).
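
A minimal sketch of the narrowing in item 2, under assumed names (`backend` stands in for an attention backend object such as FlashInfer's; the real wrappers live in sglang's attention code and may differ):

```python
def unified_attention_with_output_sketch(q, k, v, out, forward_batch, backend):
    """Hedged sketch of narrowing attention inputs to the real tokens."""
    real = forward_batch.real_num_tokens
    if real is None or real == q.shape[0]:
        # No padding present: run the backend on the full tensors.
        backend.forward(q, k, v, out, forward_batch)
        return
    saved_loc = forward_batch.out_cache_loc
    # Narrow the inputs and the KV-cache write locations so the backend
    # never sees padding tokens; tensor slicing creates views, not copies.
    forward_batch.out_cache_loc = saved_loc[:real]
    backend.forward(q[:real], k[:real], v[:real], out[:real], forward_batch)
    # Restore the full out_cache_loc so graph-owned buffers keep their
    # padded shapes for the next replay.
    forward_batch.out_cache_loc = saved_loc
```

Because only views are passed down, the padded tail of `out` is simply left untouched instead of being filled with garbage by the attention kernel.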

Accuracy Tests

Tested in CI

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @Chen-0210, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements a bug fix to prevent padding tokens from being included in attention calculations within the Piecewise CUDA Graph (PCG) execution. By introducing a real_num_tokens field to ForwardBatch and modifying attention layers to process only the actual tokens, the change ensures more accurate and efficient attention computations, eliminating the influence of padding on model outputs.

Highlights

  • Padding Token Exclusion: Attention mechanisms in radix_attention.py and qwen3_next.py now explicitly ignore padding tokens by slicing input tensors (query, key, value, hidden_states, rope, sinks) based on a new real_num_tokens field, ensuring computations only involve actual data.
  • Introduction of real_num_tokens: A new optional integer field, real_num_tokens, has been added to the ForwardBatch class to accurately track the number of non-padding tokens in a batch (see the sketch after this list).
  • Piecewise CUDA Graph Runner Refactoring: The piecewise_cuda_graph_runner.py has been updated to correctly pass and utilize real_num_tokens during graph warmup and capture. It also streamlines the handling of out_cache_loc and out_cache_loc_swa by making them dynamically created or passed via ForwardBatch rather than persistent graph runner attributes.
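
As a rough illustration of the second highlight, the new field might look like this on the dataclass (surrounding fields omitted; the actual definition lives in python/sglang/srt/model_executor/forward_batch_info.py):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ForwardBatch:
    # ... existing fields omitted from this sketch ...

    # Number of non-padding tokens in the batch; None when the batch
    # carries no padding (e.g., piecewise CUDA graph is not in use).
    real_num_tokens: Optional[int] = None
```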


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a bugfix for handling padding tokens within the Piecewise CUDA Graph (PCG) execution path. The core change is the addition of a real_num_tokens field to the ForwardBatch dataclass, which allows distinguishing actual tokens from padding. This field is then utilized in custom attention operations (unified_attention_with_output and gdn_with_output) to correctly slice input tensors, ensuring that only real tokens are processed during attention computation. Consequently, the logic in PiecewiseCudaGraphRunner for handling out_cache_loc has been simplified by removing pre-allocated tensors and passing them directly from the forward_batch.

The changes appear correct and effectively address the padding issue in PCG mode. I have one minor suggestion to improve code clarity by correcting a duplicated comment.
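
To make the data flow concrete, here is a hedged sketch of how a runner could record real_num_tokens while padding a batch up to a captured graph size; pad_for_graph and captured_sizes are illustrative names, not PiecewiseCudaGraphRunner's actual API:

```python
import torch

def pad_for_graph(forward_batch, captured_sizes):
    """Pad a batch to the next captured graph size, recording the real
    token count so attention can later skip the padding (sketch only)."""
    real = forward_batch.input_ids.shape[0]
    # Pick the smallest captured size that fits the real token count.
    padded = next(s for s in sorted(captured_sizes) if s >= real)
    forward_batch.real_num_tokens = real
    if padded > real:
        pad = padded - real
        forward_batch.input_ids = torch.cat(
            [forward_batch.input_ids, forward_batch.input_ids.new_zeros(pad)]
        )
        # positions, out_cache_loc, etc. would be padded the same way.
    return padded
```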

@Chen-0210 changed the title from "[bugfix] ignore padding tokens for attention in PCG" to "[bugfix] avoid attention padding tokens computation in pcg" on Jan 25, 2026
@Oasis-Git (Collaborator) left a comment

Since these changes affect a broad range of functionality, the full unit test suite should be carefully validated before approving this PR.

@Chen-0210 (Contributor, Author)

/rerun-failed-ci

12 similar comments

@Oasis-Git (Collaborator) left a comment

I think this is a clean change. Approved.

@ispobock (Collaborator) commented

@Chen-0210 could you fix the test conflict?

@Chen-0210 (Contributor, Author) commented

> @Chen-0210 could you fix the test conflict?

Done

@ispobock ispobock merged commit 6760c79 into sgl-project:main Apr 14, 2026
35 of 74 checks passed
@Chen-0210 Chen-0210 deleted the fix-qwen3-next branch April 14, 2026 08:29
bingxche added a commit that referenced this pull request Apr 15, 2026
